import torch
import random
from torch.utils.data import DataLoader
import ShiftingWindowSetting as sw


class AGEM(sw.CLLearningAlgo):
    """Averaged Gradient Episodic Memory (A-GEM) continual-learning algorithm.

    At the end of every task a small random sample of that task's data is
    stored in an episodic memory. In the ``before_batch_calc`` hook the
    gradient of the loss over the whole memory (the reference gradient
    ``g_ref``) is computed. In the ``after_loss_calc`` hook the gradient of the
    current batch ``g`` is projected whenever ``g . g_ref < 0``, so that the
    update does not increase the loss on the stored memory.

    Args:
        args: forwarded unchanged to ``sw.CLLearningAlgo.__init__``.
        mem_samples_per_task: number of samples drawn (without replacement)
            from each finished task and stored in memory.
        mem_batch_size: batch size used when iterating over a task's memory
            to compute the reference gradient.
    """

    def __init__(self, args, mem_samples_per_task=13, mem_batch_size=1000):
        super().__init__(args=args)
        self.mem_samples_per_task = mem_samples_per_task
        self.mem_batch_size = mem_batch_size
        # Per-instance state: a class-level mutable default would be shared
        # (and appended to) by every AGEM instance in the process.
        # One (nullClasses, samples) tuple per finished task.
        self.mem = []
        # Reference gradient on the memory, keyed by parameter name.
        self.mem_grad = {}

    def _calc_mem_grad(self):
        """Compute the reference gradient of the loss over the whole memory.

        The loss of every memory batch of every stored task is
        back-propagated, accumulating into ``.grad``. The result is snapshotted
        into ``self.mem_grad`` and the parameter gradients are cleared again.
        Does nothing while the memory is empty.
        """
        if not self.mem:
            return

        was_training = self.model.training
        self.model.eval()
        self.model.zero_grad()
        for null_classes, task_mem in self.mem:
            data_loader = DataLoader(task_mem, batch_size=self.mem_batch_size)
            for X, Y in data_loader:
                X, Y = X.to(self.device), Y.to(self.device)
                out = sw.calc_model_output(self.model, X, null_classes)
                # backward() accumulates, so mem_grad is the sum over all
                # memory batches; its overall scale cancels in the projection.
                self.loss_fn(out, Y).backward()
        self.mem_grad = self._get_grad()
        self.model.zero_grad()
        # Restore whatever mode the model was in before this call.
        self.model.train(was_training)

    def _grad_project(self):
        """Project the current batch gradient if it conflicts with the memory.

        Reads the gradient ``g`` currently stored in ``.grad`` and the
        reference gradient ``g_ref = self.mem_grad``. If ``g . g_ref >= 0``
        nothing is changed. Otherwise ``.grad`` is overwritten with

            g - (g . g_ref / g_ref . g_ref) * g_ref

        which is orthogonal to ``g_ref``.
        """
        if not self.mem:
            return

        task_grad = self._get_grad()
        task_dot_mem = self._grad_dot_product(task_grad, self.mem_grad)

        # If the current gradient does not hurt average memory performance,
        # leave it alone.
        if task_dot_mem.item() >= 0:
            return

        # Guard against division by (near) zero when g_ref vanishes.
        mem_dot_mem = self._grad_dot_product(self.mem_grad, self.mem_grad).clamp(min=1e-6)
        scale = task_dot_mem / mem_dot_mem
        for name, grad in task_grad.items():
            # Subtract the g_ref component. Adding it instead would double
            # the conflict rather than remove it.
            grad -= scale * self.mem_grad[name]
        self._save_grad(task_grad)

    def _save_grad(self, grad):
        """Overwrite ``.grad`` of every trainable parameter with ``grad[name]``."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = grad[name]

    def _get_grad(self):
        """Return a detached copy of ``.grad`` for every trainable parameter.

        A parameter whose ``.grad`` is None (not reached by the last backward
        pass) is represented by zeros, so every trainable parameter has an
        entry and the dot products/projection stay well defined.
        """
        param_grad = {}
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                param_grad[name] = torch.zeros_like(param)
            else:
                param_grad[name] = param.grad.detach().clone()
        return param_grad

    def _grad_dot_product(self, a, b):
        """Return the dot product of two gradient dicts as a 0-dim tensor.

        Sums ``a[name] * b[name]`` over all keys of ``a``. The accumulator
        lives on ``self.device``.
        """
        res = torch.zeros((), device=self.device)
        for name, grad in a.items():
            res += torch.sum(grad * b[name])
        return res

    def _update_memory(self, data):
        """Append one task's ``(nullClasses, samples)`` entry to the memory."""
        self.mem.append(data)

    def before_batch_calc(self):
        """Hook: refresh the reference gradient on the memory."""
        self._calc_mem_grad()

    def after_loss_calc(self):
        """Hook: project the current gradient against the reference gradient."""
        self._grad_project()

    def at_end_of_task(self):
        """Store a random sample of the finished task's data in memory.

        Draws ``mem_samples_per_task`` samples without replacement, or all of
        them if the task has fewer.
        """
        # The sampling procedure currently only works for shifting-window
        # streams. Other streams need their own way to access task data
        # before A-GEM can be used with them.
        task_data = self.task_stream.window_data[self.task_id]
        n_samples = min(self.mem_samples_per_task, len(task_data))
        self._update_memory((self.nullClasses, random.sample(task_data, n_samples)))

